package com.alks.common.utils.excelUtils;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.excel.metadata.Head;
import com.alibaba.excel.write.merge.AbstractMergeStrategy;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.util.CellRangeAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Merge strategy that vertically merges runs of consecutive equal values in a set of
 * target columns.
 *
 * <p>The two constructor arguments are parallel lists: the values at position {@code p} of
 * {@code mergeColDataList} belong to the sheet column {@code targetColumnIndex.get(p)}.
 * For every target column the values are scanned top to bottom and each run of two or more
 * equal neighbours becomes one merged region.
 *
 * <p>The instance remembers the row of the first cell handed to {@link #merge}, and all
 * merged regions are laid out starting from that row. The remembered row is never reset, so
 * an instance carries that state for its whole lifetime.
 */
public class CustomMergeStrategy extends AbstractMergeStrategy {

	/**
	 * Run lengths per target column: element {@code p} lists, in row order, how many
	 * consecutive rows share the same value in the column {@code targetColumnIndex.get(p)}.
	 */
	private final List<List<Integer>> mergeColDataGroupCountList;

	/**
	 * Sheet column indexes of the columns to merge; parallel to
	 * {@link #mergeColDataGroupCountList}.
	 */
	private final List<Integer> targetColumnIndex;

	/**
	 * Row index of the first cell seen by {@link #merge}; merged regions start at this row.
	 * {@code null} until the first call.
	 */
	private Integer rowIndex;

	/**
	 * Creates the strategy.
	 *
	 * @param mergeColDataList  values of the columns to merge, one inner list per target
	 *                          column, in row order; {@code null} or empty disables merging
	 * @param targetColumnIndex sheet column index for each inner list of
	 *                          {@code mergeColDataList}; {@code null} is treated as empty
	 */
	public CustomMergeStrategy(List<List<String>> mergeColDataList, List<Integer> targetColumnIndex) {
		this.mergeColDataGroupCountList = getGroupCountList(mergeColDataList);
		this.targetColumnIndex = targetColumnIndex == null ? new ArrayList<>() : targetColumnIndex;
	}

	/**
	 * Registers the merged regions of a target column when the cell in the first row of
	 * that column is processed; every other cell is ignored.
	 */
	@Override
	protected void merge(Sheet sheet, Cell cell, Head head, Integer relativeRowIndex) {
		if (null == rowIndex) {
			rowIndex = cell.getRowIndex();
		}
		// Only the first row triggers merging; the regions of the whole column are added at once.
		if (cell.getRowIndex() != rowIndex) {
			return;
		}
		// Position of this column inside targetColumnIndex. The position, not the column
		// number itself, is what indexes the parallel run-length list.
		int position = targetColumnIndex.indexOf(cell.getColumnIndex());
		if (position < 0) {
			return;
		}
		mergeGroupColumn(sheet, position);
	}

	/**
	 * Adds one merged region per run of length greater than one for a single target column.
	 *
	 * @param sheet    sheet that receives the merged regions
	 * @param position position of the column inside {@link #targetColumnIndex}
	 */
	private void mergeGroupColumn(Sheet sheet, int position) {
		// No run-length data for this column (no data supplied, or fewer value lists than
		// target columns): nothing to merge.
		if (position >= mergeColDataGroupCountList.size()) {
			return;
		}
		int column = targetColumnIndex.get(position);
		int firstRow = rowIndex;
		for (int count : mergeColDataGroupCountList.get(position)) {
			// A run of a single row needs no merged region.
			if (count > 1) {
				sheet.addMergedRegionUnsafe(
						new CellRangeAddress(firstRow, firstRow + count - 1, column, column));
			}
			firstRow += count;
		}
	}

	/**
	 * Converts the values of every target column into the lengths of its runs of
	 * consecutive equal values.
	 *
	 * @param exportDataList values per target column, in row order; may be {@code null}
	 * @return one run-length list per input list, in the same order; empty when the input
	 *         is {@code null} or empty
	 */
	private List<List<Integer>> getGroupCountList(List<List<String>> exportDataList) {
		List<List<Integer>> groupCountListList = new ArrayList<>();
		if (CollUtil.isEmpty(exportDataList)) {
			return groupCountListList;
		}
		for (List<String> dataList : exportDataList) {
			// Always add an entry, even for a null/empty column, so positions stay aligned
			// with targetColumnIndex.
			groupCountListList.add(countConsecutiveRuns(dataList));
		}
		return groupCountListList;
	}

	/**
	 * Computes the lengths of the runs of consecutive equal values in one column.
	 * {@code null} values are compared null-safely, so neighbouring nulls form a run.
	 *
	 * @param dataList column values in row order; may be {@code null}
	 * @return run lengths in row order; empty when {@code dataList} is {@code null} or empty
	 */
	private static List<Integer> countConsecutiveRuns(List<String> dataList) {
		List<Integer> groupCountList = new ArrayList<>();
		if (dataList == null || dataList.isEmpty()) {
			return groupCountList;
		}
		int count = 1;
		for (int i = 1; i < dataList.size(); i++) {
			if (Objects.equals(dataList.get(i), dataList.get(i - 1))) {
				count++;
			} else {
				groupCountList.add(count);
				count = 1;
			}
		}
		// Flush the run that ends with the last value.
		groupCountList.add(count);
		return groupCountList;
	}
}
